# encoding: utf-8


import os
import h5py
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
from keras.applications import xception
from keras.applications.xception import Xception
from keras.layers.core import Lambda
from keras.layers import Input, GlobalAveragePooling2D, GlobalMaxPooling2D, Dense, Dropout
from keras.models import Model
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.optimizers import adam, Adam
from sklearn.utils import shuffle
from keras.utils import to_categorical
from keras.utils.vis_utils import plot_model
from loss_history import LossHistory
from keras import backend as K
K.clear_session()


# -------------------------------------
# Hyper-parameters
# -------------------------------------
EPOCHS = 15
BATCH_SIZE = 6  # batch size used for fine-tuning (extractFeatures overrides locally)
IMAGE_SIZE = (512, 512)  # target (height, width) for flow_from_directory resizing
XCEPTION_NO_TRAINABLE_LAYERS = 0  # freeze layers [0, N); 0 means every layer stays trainable
LEARNING_RATE = 1e-4  # learning rate
EARLY_STOPPING_PATIENCE = 3
REDUCE_LR_PATIENCE = 3  # NOTE(review): defined but no ReduceLROnPlateau callback is created in this file
CLASS_NUM = 2  # number of output classes
DROP_RATE = 0.5


# -------------------------------------
# File paths
# -------------------------------------
TRAIN_DATA_PATH = './data/train/'
VALID_DATA_PATH = './data/valid/'
RESULT_PATH = './result/'
MODEL_PATH = './model/'
OUTPUT_PATH = './output/'
# NOTE(review): "xecption" looks like a typo for "xception"; kept as-is because the
# file name is runtime behavior (both writer and reader use this same constant).
MODEL_NAME = 'xecption-finetune-model.hdf5'


def my_preprocess(x):
    """Rescale raw 8-bit pixel values from [0, 255] to floats in [0.0, 1.0]."""
    return x / 255.0


def finetuneModel():
    """Fine-tune an Xception backbone (trained from scratch, weights=None) on
    grayscale images and save the best checkpoint plus loss/accuracy curves.

    Data: directories under TRAIN_DATA_PATH / VALID_DATA_PATH, read with the
    flow_from_directory default class_mode='categorical' (one-hot labels).
    Outputs: best model at MODEL_PATH/MODEL_NAME, plots under RESULT_PATH.
    """
    train_gen = ImageDataGenerator()
    valid_gen = ImageDataGenerator()
    train_generator = train_gen.flow_from_directory(TRAIN_DATA_PATH, IMAGE_SIZE, shuffle=True, batch_size=BATCH_SIZE, color_mode='grayscale')
    valid_generator = valid_gen.flow_from_directory(VALID_DATA_PATH, IMAGE_SIZE, batch_size=BATCH_SIZE, color_mode='grayscale')

    # Single-channel input; normalization happens inside the graph so saved
    # models carry their own preprocessing.
    inputs = Input((*IMAGE_SIZE, 1))
    x = Lambda(my_preprocess)(inputs)
    # Backbone without the ImageNet classification head.
    base_model = Xception(input_tensor=x, weights=None, include_top=False, pooling=None)
    x = GlobalAveragePooling2D(name='my_global_average_pooling_layer_1')(base_model.output)
    x = Dropout(DROP_RATE, name='my_dropout_layer_1')(x)
    predictions = Dense(CLASS_NUM, activation='softmax', name='my_dense_layer_1')(x)
    model = Model(base_model.input, predictions)
    plot_model(model, to_file=os.path.join(RESULT_PATH, 'xception.png'), show_shapes=True)

    # Freeze layers [0, XCEPTION_NO_TRAINABLE_LAYERS); with the current value 0,
    # every layer remains trainable.
    for layer in model.layers[:XCEPTION_NO_TRAINABLE_LAYERS]:
        layer.trainable = False
    for layer in model.layers[XCEPTION_NO_TRAINABLE_LAYERS:]:
        layer.trainable = True

    model.summary()

    # Sanity check: dump each layer's index, name, and trainable flag.
    for i, layer in enumerate(model.layers):
        print('{}: {}, {}'.format(i, layer.name, layer.trainable))

    # Keep only the best model (by validation loss) on disk.
    check_point = ModelCheckpoint(monitor='val_loss',
                                  filepath=os.path.join(MODEL_PATH, MODEL_NAME),
                                  verbose=1,
                                  save_best_only=True,
                                  save_weights_only=False,
                                  mode='auto')

    # Stop when validation loss plateaus.
    early_stopping = EarlyStopping(monitor='val_loss', patience=EARLY_STOPPING_PATIENCE, verbose=0, mode='auto')

    # Records per-batch and per-epoch loss/accuracy for the plots below.
    history = LossHistory()

    # Fixes vs. the original:
    #  - Adam (documented class) instead of the undocumented lowercase `adam` alias.
    #  - categorical_crossentropy: labels are one-hot (class_mode='categorical')
    #    and the head is a CLASS_NUM-way softmax; binary_crossentropy would be
    #    applied elementwise and also make 'accuracy' resolve to binary accuracy.
    model.compile(optimizer=Adam(lr=LEARNING_RATE), loss='categorical_crossentropy', metrics=['accuracy'])

    model.fit_generator(
        train_generator,
        steps_per_epoch=train_generator.samples // BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=valid_generator,
        validation_steps=valid_generator.samples // BATCH_SIZE,
        callbacks=[check_point, early_stopping, history]
    )

    # Plot loss/accuracy curves per batch and per epoch.
    history.loss_plot('batch', os.path.join(RESULT_PATH, 'xception_loss_batch.png'))
    history.acc_plot('batch', os.path.join(RESULT_PATH, 'xception_acc_batch.png'))
    history.loss_plot('epoch', os.path.join(RESULT_PATH, 'xception_loss_epoch.png'))
    history.acc_plot('epoch', os.path.join(RESULT_PATH, 'xception_acc_epoch.png'))


def extractFeatures():
    """Run the fine-tuned Xception backbone over train/valid sets (in directory
    order, shuffle=False) and store pooled features plus labels in an HDF5 file
    under OUTPUT_PATH.
    """
    batch_size = 4  # smaller than the training BATCH_SIZE; shadows the global on purpose

    # Deterministic, un-shuffled iterators so features align with .classes below.
    train_iter = ImageDataGenerator().flow_from_directory(
        TRAIN_DATA_PATH, IMAGE_SIZE, shuffle=False, batch_size=batch_size, color_mode='grayscale')
    valid_iter = ImageDataGenerator().flow_from_directory(
        VALID_DATA_PATH, IMAGE_SIZE, shuffle=False, batch_size=batch_size, color_mode='grayscale')

    # Rebuild the same graph topology as training, ending at global average
    # pooling, then pull in the fine-tuned weights by layer name.
    inputs = Input((*IMAGE_SIZE, 1))
    normalized = Lambda(my_preprocess)(inputs)
    backbone = Xception(input_tensor=normalized, weights=None, include_top=False, pooling=None)
    feature_model = Model(backbone.input, GlobalAveragePooling2D()(backbone.output))
    feature_model.load_weights(os.path.join(MODEL_PATH, MODEL_NAME), by_name=True)

    # Whole batches only: any remainder after // batch_size is dropped, and the
    # label arrays are truncated to match below.
    train_steps = len(train_iter.filenames) // batch_size
    valid_steps = len(valid_iter.filenames) // batch_size
    train_features = feature_model.predict_generator(
        train_iter, steps=train_steps, use_multiprocessing=True, workers=8, verbose=1)
    valid_features = feature_model.predict_generator(
        valid_iter, steps=valid_steps, use_multiprocessing=True, workers=8, verbose=1)

    n_train = (train_iter.samples // batch_size) * batch_size
    n_valid = (valid_iter.samples // batch_size) * batch_size
    with h5py.File(os.path.join(OUTPUT_PATH, 'xception-finetune-output.hdf5'), 'w') as h:
        h.create_dataset('X_train', data=train_features)
        h.create_dataset('y_train', data=train_iter.classes[:n_train])
        h.create_dataset('X_val', data=valid_features)
        h.create_dataset('y_val', data=valid_iter.classes[:n_valid])


if __name__ == '__main__':
    # Fine-tune first so the checkpoint exists, then reuse its weights to dump features.
    finetuneModel()
    extractFeatures()
